Cache RoPE freqs on device to avoid repeated CPU-GPU copy in QwenImage #13406
akshan-main wants to merge 1 commit into huggingface:main
Conversation
The profiling was done with 2 steps, but this sync happens on every transformer forward call, so at 20 inference steps this eliminates ~1.5s of CPU-GPU sync overhead per run. Under torch.compile the impact is larger since GPU queues are deeper, so each sync stalls longer (80ms vs 76ms in eager).
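As a hedged illustration (this is not the PR's profiling tooling, and `freqs` is a stand-in for the real `pos_freqs`/`neg_freqs` tensors), you can watch the repeated per-call transfer show up in a profiler trace. On a CUDA machine each iteration would enqueue a fresh host-to-device Memcpy; on CPU we can still count the repeated `aten::to` dispatches:

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Stand-in for the fixed-at-init frequency tensors.
freqs = torch.randn(4096, 64)
device = "cuda" if torch.cuda.is_available() else "cpu"

with profile(activities=[ProfilerActivity.CPU]) as prof:
    for _ in range(5):
        # One redundant transfer per "forward" call; copy=True forces
        # a real copy even when source and target device match.
        _ = freqs.to(device, copy=True)

to_calls = [e for e in prof.key_averages() if e.key == "aten::to"]
print(to_calls[0].count if to_calls else 0)
```

With the cache in place, only the first iteration would show a transfer; the remaining calls reuse the device-resident tensor.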
Oh, and this fix applies to all QwenImage variants (Edit, EditPlus, Layered) since they share the same transformer.
@akshan-main thanks for this! In the second plot, could you tell which one of the blocks the reported duration belongs to? |
The selected slice in the "after" image is the `transformer_forward` user_annotation itself (~439ms), wrapping the full `QwenImageTransformer2DModel.forward`. I'm highlighting the sub-block where the 76ms `cudaStreamSynchronize` (visible in the "before" screenshot) used to sit; it is gone now.
The ~439ms is for the entire `transformer_forward` block.
What does this PR do?
Part of #13401
`QwenEmbedRope.forward()` copies `pos_freqs` and `neg_freqs` from CPU to GPU via `.to(device)` on every transformer forward call. These tensors are fixed at init and never change, so the repeated transfer triggers an unnecessary `cudaStreamSynchronize` (~76ms each).

Added `_get_device_freqs()`, which caches the GPU copy on first call. Applied to both `QwenEmbedRope` and `QwenEmbedLayer3DRope`. (`register_buffer` can't be used here because it drops the imaginary part of complex tensors.)

Profiling (A100 80GB, eager, 2 steps, 1024x1024)
Before (76ms cudaStreamSynchronize inside transformer_forward):
After (no sync gap):
Profiled with the tooling from #13356. Reproduction notebook.
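The caching pattern the PR describes can be sketched roughly as below. This is an illustrative standalone module, not the actual diffusers code: the class name, dimensions, and frequency construction are made up, and only `_get_device_freqs` mirrors the PR's idea of lazily copying the complex tables to the target device once and reusing them afterward.

```python
import torch

class RopeFreqCache(torch.nn.Module):
    def __init__(self, dim: int = 64, max_len: int = 128):
        super().__init__()
        # Complex frequency tables built once at init (they live on CPU).
        # register_buffer is avoided since it drops the imaginary part.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        angles = torch.outer(torch.arange(max_len).float(), inv_freq)
        self.pos_freqs = torch.polar(torch.ones_like(angles), angles)   # complex64
        self.neg_freqs = torch.polar(torch.ones_like(angles), -angles)
        self._cached = {}  # device string -> (pos, neg) on that device

    def _get_device_freqs(self, device: torch.device):
        # First call per device pays the transfer; later calls return the
        # cached tensors, so no CPU->GPU copy (and no cudaStreamSynchronize)
        # happens inside the per-step forward path.
        key = str(device)
        if key not in self._cached:
            self._cached[key] = (self.pos_freqs.to(device),
                                 self.neg_freqs.to(device))
        return self._cached[key]
```

Repeated calls for the same device return the identical tensor objects, which is what removes the per-forward sync from the trace.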
Who can review?
@sayakpaul @dg845